{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deploying pre-trained PyTorch vision models with Amazon SageMaker Neo"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Amazon SageMaker Neo is an API for compiling machine learning models to optimize them for the hardware target of our choice. Currently, Neo supports pre-trained PyTorch models from [TorchVision](https://pytorch.org/docs/stable/torchvision/models.html). General support for other PyTorch models is forthcoming."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import ResNet18 from TorchVision"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll import the [ResNet18](https://arxiv.org/abs/1512.03385) model from TorchVision and package it as a model artifact, `model.tar.gz`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision.models as models\n",
    "import tarfile\n",
    "\n",
    "resnet18 = models.resnet18(pretrained=True)\n",
    "torch.save(resnet18, 'model.pth')\n",
    "\n",
    "with tarfile.open('model.tar.gz', 'w:gz') as f:\n",
    "    f.add('model.pth')"
   ]
  },
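  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an optional sanity check, the following sketch lists the contents of the archive to confirm that it contains `model.pth`. It uses only the standard `tarfile` module and assumes the cell above has already been run in the same working directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tarfile\n",
    "\n",
    "# List the archive members; we expect to see 'model.pth' among them\n",
    "with tarfile.open('model.tar.gz', 'r:gz') as f:\n",
    "    print(f.getnames())"
   ]
  },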
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Invoke Neo Compilation API"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We then upload the model artifact to S3 and submit it to the Neo Compilation API:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "import sagemaker\n",
    "import time\n",
    "from sagemaker.utils import name_from_base\n",
    "\n",
    "role = sagemaker.get_execution_role()\n",
    "sess = sagemaker.Session()\n",
    "region = sess.boto_region_name\n",
    "bucket = sess.default_bucket()\n",
    "\n",
    "compilation_job_name = name_from_base('TorchVision-ResNet18-Neo')\n",
    "\n",
    "model_key = '{}/model/model.tar.gz'.format(compilation_job_name)\n",
    "model_path = 's3://{}/{}'.format(bucket, model_key)\n",
    "boto3.resource('s3').Bucket(bucket).upload_file('model.tar.gz', model_key)\n",
    "\n",
    "sm_client = boto3.client('sagemaker')\n",
    "data_shape = '{\"input0\":[1,3,224,224]}'\n",
    "target_device = 'ml_c5'\n",
    "framework = 'PYTORCH'\n",
    "framework_version = '0.4.0'\n",
    "compiled_model_path = 's3://{}/{}/output'.format(bucket, compilation_job_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "response = sm_client.create_compilation_job(\n",
    "    CompilationJobName=compilation_job_name,\n",
    "    RoleArn=role,\n",
    "    InputConfig={\n",
    "        'S3Uri': model_path,\n",
    "        'DataInputConfig': data_shape,\n",
    "        'Framework': framework\n",
    "    },\n",
    "    OutputConfig={\n",
    "        'S3OutputLocation': compiled_model_path,\n",
    "        'TargetDevice': target_device\n",
    "    },\n",
    "    StoppingCondition={\n",
    "        'MaxRuntimeInSeconds': 300\n",
    "    }\n",
    ")\n",
    "print(response)\n",
    "\n",
    "# Poll every 30 sec\n",
    "while True:\n",
    "    response = sm_client.describe_compilation_job(CompilationJobName=compilation_job_name)\n",
    "    status = response['CompilationJobStatus']\n",
    "    if status == 'COMPLETED':\n",
    "        break\n",
    "    elif status in ('FAILED', 'STOPPED'):\n",
    "        # Stop polling on any terminal state and report the reason if one is given\n",
    "        raise RuntimeError('Compilation {}: {}'.format(status.lower(), response.get('FailureReason', 'no reason given')))\n",
    "    print('Compiling ...')\n",
    "    time.sleep(30)\n",
    "print('Done!')\n",
    "\n",
    "# Extract compiled model artifact\n",
    "compiled_model_path = response['ModelArtifacts']['S3ModelArtifacts']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create prediction endpoint"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To create a prediction endpoint, we first define two additional functions for use with the Neo Deep Learning Runtime:\n",
    "\n",
    "* `neo_preprocess(payload, content_type)`: Takes the payload and Content-Type of each incoming request and returns a NumPy array. Here, the payload is the raw bytes of an image (the request at the end of this notebook sends a JPEG with Content-Type `application/x-image`), so the function decodes the bytes into a NumPy array.\n",
    "* `neo_postprocess(result)`: Takes the prediction results produced by the Neo Deep Learning Runtime and returns the response body."
   ]
  },
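  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two functions could look roughly like the sketch below. This is not the actual contents of `resnet18.py`; it is a minimal illustration that assumes the payload is the raw bytes of an image (as in the request sent later in this notebook), that Pillow and NumPy are available, that the commonly used ImageNet mean and standard deviation are applied for normalization, and that `neo_postprocess` returns the response body together with its content type."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import io\n",
    "import json\n",
    "\n",
    "import numpy as np\n",
    "import PIL.Image\n",
    "\n",
    "\n",
    "def neo_preprocess(payload, content_type):\n",
    "    # Sketch only: decode image bytes into a (1, 3, 224, 224) float array\n",
    "    if content_type != 'application/x-image':\n",
    "        raise RuntimeError('Content type must be application/x-image')\n",
    "    image = PIL.Image.open(io.BytesIO(payload)).convert('RGB').resize((224, 224))\n",
    "    image = np.asarray(image, dtype=np.float32) / 255.0\n",
    "    # Assumption: normalize with the ImageNet statistics commonly used for TorchVision models\n",
    "    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)\n",
    "    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)\n",
    "    image = (image - mean) / std\n",
    "    # Reorder from height-width-channel to channel-height-width, then add the batch dimension\n",
    "    return np.transpose(image, (2, 0, 1))[np.newaxis, :]\n",
    "\n",
    "\n",
    "def neo_postprocess(result):\n",
    "    # Sketch only: apply a numerically stable softmax to the raw scores (batch size 1)\n",
    "    scores = np.squeeze(result)\n",
    "    exp = np.exp(scores - np.max(scores))\n",
    "    probs = exp / np.sum(exp)\n",
    "    # Assumption: the runtime expects a (response_body, content_type) pair\n",
    "    return json.dumps(probs.tolist()), 'application/json'"
   ]
  },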
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pygmentize resnet18.py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Upload the Python script containing the two functions to S3:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "source_key = '{}/source/sourcedir.tar.gz'.format(compilation_job_name)\n",
    "source_path = 's3://{}/{}'.format(bucket, source_key)\n",
    "\n",
    "with tarfile.open('sourcedir.tar.gz', 'w:gz') as f:\n",
    "    f.add('resnet18.py')\n",
    "\n",
    "boto3.resource('s3').Bucket(bucket).upload_file('sourcedir.tar.gz', source_key)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We then create a SageMaker model record:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.model import NEO_IMAGE_ACCOUNT\n",
    "from sagemaker.fw_utils import create_image_uri\n",
    "\n",
    "model_name = name_from_base('TorchVision-ResNet18-Neo')\n",
    "\n",
    "image_uri = create_image_uri(region, 'neo-' + framework.lower(), target_device.replace('_', '.'),\n",
    "                             framework_version, py_version='py3', account=NEO_IMAGE_ACCOUNT[region])\n",
    "\n",
    "response = sm_client.create_model(\n",
    "    ModelName=model_name,\n",
    "    PrimaryContainer={\n",
    "        'Image': image_uri,\n",
    "        'ModelDataUrl': compiled_model_path,\n",
    "        'Environment': { 'SAGEMAKER_SUBMIT_DIRECTORY': source_path }\n",
    "    },\n",
    "    ExecutionRoleArn=role\n",
    ")\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we create an Endpoint Configuration:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config_name = model_name\n",
    "\n",
    "response = sm_client.create_endpoint_config(\n",
    "    EndpointConfigName=config_name,\n",
    "    ProductionVariants=[\n",
    "        {\n",
    "            'VariantName': 'default-variant-name',\n",
    "            'ModelName': model_name,\n",
    "            'InitialInstanceCount': 1,\n",
    "            'InstanceType': 'ml.c5.xlarge',\n",
    "            'InitialVariantWeight': 1.0\n",
    "        },\n",
    "    ],\n",
    ")\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we create an Endpoint:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "endpoint_name = model_name + '-Endpoint'\n",
    "\n",
    "response = sm_client.create_endpoint(\n",
    "    EndpointName=endpoint_name,\n",
    "    EndpointConfigName=config_name,\n",
    ")\n",
    "print(response)\n",
    "\n",
    "print('Creating endpoint ...')\n",
    "waiter = sm_client.get_waiter('endpoint_in_service')\n",
    "waiter.wait(EndpointName=endpoint_name)\n",
    "\n",
    "response = sm_client.describe_endpoint(EndpointName=endpoint_name)\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Send requests"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's try to send a cat picture.\n",
    "\n",
    "![title](cat.jpg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "\n",
    "sm_runtime = boto3.Session().client('sagemaker-runtime')\n",
    "\n",
    "with open('cat.jpg', 'rb') as f:\n",
    "    payload = f.read()\n",
    "response = sm_runtime.invoke_endpoint(EndpointName=endpoint_name,\n",
    "                                      ContentType='application/x-image',\n",
    "                                      Body=payload)\n",
    "print(response)\n",
    "result = json.loads(response['Body'].read().decode())\n",
    "print('Most likely class: {}'.format(np.argmax(result)))"
   ]
  },
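  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Beyond the single most likely class, it can be useful to look at the top few candidates. The sketch below assumes `result` is the list of per-class scores decoded in the previous cell; `top_k_classes` is a hypothetical helper written for this illustration, not part of any SageMaker API."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def top_k_classes(scores, k=5):\n",
    "    # Return (class_index, score) pairs for the k highest scores, best first\n",
    "    scores = np.asarray(scores).ravel()\n",
    "    order = np.argsort(scores)[::-1][:k]\n",
    "    return [(int(i), float(scores[i])) for i in order]\n",
    "\n",
    "\n",
    "for class_index, score in top_k_classes(result):\n",
    "    print('class {}: {:.4f}'.format(class_index, score))"
   ]
  },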
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_pytorch_p36",
   "language": "python",
   "name": "conda_pytorch_p36"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
